# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
# This source file is part of the Cangjie project, licensed under Apache-2.0
# with Runtime Library Exception.
#
# See https://cangjie-lang.cn/pages/LICENSE for license information.

from os import path
from random import sample, choice, randint, seed
from itertools import product
import textwrap

# Fixed-width integer type names: signed IntN first, then unsigned UIntN, by width.
integer_types = ['%sInt%d' % (sign, width) for sign in ('', 'U') for width in (8, 16, 32, 64)]
# All numeric type names: the integers followed by the three float widths.
number_types = integer_types + ['Float%d' % width for width in (16, 32, 64)]
# Every builtin type name the generator pairs up (numbers first, then the rest).
integral_types = number_types + ['String', 'Rune', 'Bool', 'Unit', 'Object', 'Tuple',
                                 'Array', 'Range', 'IntNative', 'UIntNative']
# Cangjie source text for a sample value of each non-numeric named type.
default_value_map = dict(Bool='true', Unit='()', Rune="'1'", Object='Object()', String='"1"', C='C()',
                         Tuple='(1, 1)', Array='[1]', Range='1..2', Nothing='nothing()',
                         IntNative='iN()', UIntNative='uN()')

def tuple_types(t : str):
  """Return the ', '-separated names between the first '(' and the first ')' of *t*.

  For a function type such as '(A, B) -> (A, C)' this yields the parameter
  list ['A', 'B']. Raises ValueError if *t* lacks '(' or ')'.
  """
  open_at = t.index('(')
  close_at = t.index(')')
  inner = t[open_at + 1:close_at]
  return inner.split(', ')

def default_value(ty : str) -> str:
  """Return Cangjie source text for an expression whose type is *ty*.

  Handles the fixed-width integer and float names, the custom classes, impls
  and interfaces, function types '(P1, P2) -> R' (as a lambda ignoring its
  parameters), tuple types '(T1, T2, ...)', and any name in default_value_map
  (which raises KeyError for an unknown name).
  """
  if ty in integer_types:
    # 'Int8' -> '1i8', 'UInt64' -> '1u64': signedness letter plus bit width.
    return '1' + ty[0].lower() + ty[ty.index('t') + 1:]
  if ty.startswith('Float'):
    # Prefix test, not substring: a tuple/function type that merely *contains*
    # a Float name must fall through to its own branch below.
    return '1.0f' + ty[5:]
  if ty in classes or ty in impls:
    return ty + '()'
  if ty in interfaces:
    # Presumably a factory on the matching impl class declared in ../aux_decl.cj
    # (not visible here) - e.g. 'I' -> 'CI.i()'.
    return 'C%s.%s()' % (ty, ty.lower())
  if '->' in ty:
    params = tuple_types(ty)
    # strip() so a non-tuple return type such as ' B' resolves as 'B'.
    ret = ty[ty.index('->') + 2:].strip()
    binders = ', '.join('x%i : %s' % (i, p) for i, p in enumerate(params))
    return '{%s => (%s)}' % (binders, default_value(ret))
  if '(' in ty:
    return '(%s)' % ', '.join(default_value(t) for t in tuple_types(ty))
  return default_value_map[ty]

# Directory holding this generator; generated tests are written alongside it.
dir = path.dirname(path.realpath(__file__))
# NOTE: this rebinds `path` from the os.path module to the output-file name
# template '<dir>/test_<basename of dir>_{}.cj'; the '{}' slot receives the test number.
path = '%s/test_%s_{}.cj' % (dir, path.basename(dir))
# Skeleton of every generated .cj test file. The seven %s slots are, in order:
# @Description text, @Mode, @Negative, @CompileWarning, extra header lines
# appended right after the @Comment text (e.g. an @Issue tag, or ''),
# top-level declarations, and the body of main().
template = '''
/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This source file is part of the Cangjie project, licensed under Apache-2.0
 * with Runtime Library Exception.
 *
 * See https://cangjie-lang.cn/pages/LICENSE for license information.
 */

/*
  @Assertion:   4.22.2(3)        e1 ~> e2                              |            The lambda expression
                ------------------------------------------------------------------------------------------
                The type f implements the single-parameter operator () | let f = e1; let g = e2;
                overloading function, and g is a function type, and    | {x => g(f.operator()(x))}
                the return value type of f.operator() is a subtype of  |
                the argument type of g                                 |
  @Description: %s
  @Mode: %s
  @Negative: %s
  @Structure: complex-main
  @Dependencies: ../aux_decl.cj
  @CompileWarning: %s
  @Comment: Auto-generated by gen.py%s
*/
%s

main() {
    %s
}
'''
# Per-pair snippet for the positive test's main(), filled via str.format:
# the positional {} is the parameter type of g and {i} is the pair index.
# It runs f ~> g and the equivalent lambda { x => g(f(x)) } on "1" and asserts
# both leave the same trace in the global logStr. {{ }} are literal braces.
aux_template = '''let f{i} = F{i}()
    let g{i} = {{ arg : {} => logStr += " ~> g{i})" }}
    logStr = ""
    (f{i} ~> g{i})("1")
    Assert.equals("1 |> (f{i} ~> g{i})", logStr)
    logStr = ""
    {{ x : String => g{i}(f{i}(x)) }}("1")
    Assert.equals("1 |> (f{i} ~> g{i})", logStr)'''
# Declaration of a class F<suffix> overloading operator () on a String argument.
# Slots: the class-name suffix, then the operator body. Callers put a value of
# the intended return type last in the body - presumably so it becomes the
# operator's (inferred) return value.
decl = '''
class F%s {
    operator func ()(arg : String) { %s }
}
'''

# Fixed seed so random.sample picks the same negative pairs on every run.
seed(123)
# Number of builtin-type pairs sampled as negative tests.
limit = 100
# Custom type hierarchy (presumably mirroring ../aux_decl.cj - TODO confirm).
# Each key maps to the concatenated single-letter names of its listed supertypes.
classes = dict(A='JI', B='AJI', C='BAJI', D='AJI')
interfaces = dict(I='', J='I', K='JI', L='I')
impls = dict(CI='I', CJ='JI', CK='KJI', CL='LI')
custom_type_dict_list = [classes, interfaces, impls]
custom_type_dict = dict(classes)
custom_type_dict.update(interfaces)
custom_type_dict.update(impls)
# All custom type names: classes, then interfaces, then impls.
custom_types = list(classes) + list(interfaces) + list(impls)
types = integral_types + custom_types

def subtype(t0, t1):
  """Return whether type *t0* is a subtype of (or equal to) type *t1*.

  Tuples of type names are compared element-wise and only relate to tuples of
  the same arity. A named type relates to itself and to each single-letter
  supertype name listed for it in classes, interfaces or impls.
  """
  if isinstance(t0, tuple):
    # Arity must match: zip alone would silently accept a longer t1.
    return len(t0) == len(t1) and all(subtype(a, b) for a, b in zip(t0, t1))
  # tuple(...) turns 'JI' into ('J', 'I'): membership is per supertype name,
  # not substring containment (which would accept e.g. 'JI' or '').
  return t0 == t1 or any(t0 in table and t1 in tuple(table[t0])
                         for table in (classes, interfaces, impls))

def prod_list(s1, s2):
  """Return the Cartesian product of *s1* and *s2* as a list of 2-tuples (itertools.product order)."""
  return [*product(s1, s2)]

def write_counted(contents : str):
  """Write *contents* to the next numbered test file, then advance the counter.

  The file name is the module-level `path` template with its '{}' slot filled
  by the global `counter`, zero-padded to three digits.
  """
  global counter
  target = path.format('%03d' % counter)
  with open(target, 'w') as out:
    out.write(contents)
    counter += 1

# Builtin pairs: every type relates to itself, Nothing relates to every type,
# and class A relates to Object.
subtypes = [(name, name) for name in integral_types]
subtypes += [('Nothing', name) for name in types]
subtypes.append(('A', 'Object'))
# Distinct builtin pairs are all treated as non-subtypes; keep a fixed-size random sample.
non_subtypes = sample([pair for pair in product(integral_types, repeat=2) if pair[0] != pair[1]], limit)

def _route(judged_pairs):
  """Append each (pair, verdict) to subtypes when verdict is truthy, else to non_subtypes."""
  for pair, verdict in judged_pairs:
    (subtypes if verdict else non_subtypes).append(pair)

# Tuple types compared element-wise against (A, B, C).
_route((('(A, B, C)', '(%s, %s, %s)' % combo), subtype(('A', 'B', 'C'), combo))
       for combo in product(classes, repeat=3))
# Function types: parameters checked contravariantly, results covariantly.
_route((('(A, B) -> (A, C)', '(%s, %s) -> (%s, %s)' % sig),
        subtype(sig[:2], ('A', 'B')) and subtype(('A', 'C'), sig[2:]))
       for sig in product(classes, repeat=4))
# Plain custom types: interface/interface, class/class, impl/interface, class/interface.
for lhs_pool, rhs_pool in ((interfaces, interfaces), (classes, classes), (impls, interfaces), (classes, interfaces)):
  _route((pair, subtype(*pair)) for pair in prod_list(lhs_pool, rhs_pool))

# Output files are numbered from 001: the single positive test comes first,
# followed by one compile-only negative test per non-subtype pair.
counter = 1
# One class F<i> per pair; its operator() logs the call and ends with a value
# of the pair's left-hand (return) type.
positive_decls = ''.join(decl % (i, 'logStr += arg + " |> (f%s"; %s' % (i, default_value(ret_ty)))
                         for i, (ret_ty, _) in enumerate(subtypes))
# One check per pair that f ~> g behaves like { x => g(f(x)) }; g takes the right-hand type.
positive_checks = '\n\n    '.join(aux_template.format(arg_ty, i=i) for i, (_, arg_ty) in enumerate(subtypes))
positive_test = template % ('''Checks that composition f ~> g is permitted when return type T1 of f is a subtype of argument type T2
                of g and for various T1 and T2 and behaves like { x => g(f(x)) }.''', 'run', 'no', 'ignore', '\n  @Issue: 6523', '''
from utils import utils.assert.Assert

var logStr = ""
''' + positive_decls, positive_checks)
write_counted(positive_test)
# Negative tests: f returns ret_ty and g accepts arg_ty, where ret_ty is not a
# subtype of arg_ty. The description must name them in that order (the
# original formatted them swapped).
negative_tests = [template % ('''Checks that composition f ~> g is not permitted when return type of f is %s
                and argument type of g is %s.''' % (ret_ty, arg_ty), 'compileonly', 'yes', 'no', '',
    decl % ('', default_value(ret_ty)), '''let f = F()
    let g = { arg : %s => () }
    (f ~> g)("1")''' % arg_ty) for ret_ty, arg_ty in non_subtypes]
for test in negative_tests:
  write_counted(test)
